from AL_Dataset import ActiveLearning_Framework
import torch
from torch.nn import Module
from torch.nn import functional as F
from tqdm import tqdm
import utils
import numpy as np
from sklearn.metrics import pairwise_distances
from sklearn import svm
import math
import pdb
from scipy import stats
from copy import deepcopy as deepcopy
from matplotlib import pyplot as plt

'''
Baseline Active Learning Methods: 
1. Random Sampling
2. Max Entropy
3. K-Center
4. BADGE
''' 
class RandomSampling:
    """Query strategy that picks unlabeled samples uniformly at random."""

    def __init__(self, ALset:ActiveLearning_Framework, num_queried_each_round):
        """
        set up random sampling as the query strategy
        :param: ALset: (class: ActiveLearning_Framework) pool manager exposing `unlabeled_idx` and `Update_AL_Datapool`
        :param: num_queried_each_round: (int) number of queried data each AL iteration
        """
        self.num_queried = num_queried_each_round
        self.ALset = ALset

    def query(self):
        """
        shuffle the positions of the unlabeled pool, keep the first `num_queried`
        of them and pass the matching dataset indices to `ALset.Update_AL_Datapool`
        """
        pool = self.ALset.unlabeled_idx
        # A full permutation is drawn first; only its head is used.
        shuffled_positions = torch.randperm(len(pool)).tolist()
        picked = [pool[pos] for pos in shuffled_positions[:self.num_queried]]
        self.ALset.Update_AL_Datapool(picked)


class MaxEntropy:
    """Query strategy that labels the unlabeled samples with the highest entropy score."""

    def __init__(self, model: Module, ALset:ActiveLearning_Framework, num_queried_each_round, batch_size, device):
        """
        use Max Entropy as query strategy, return a subset of data with top entropy score
        :param: model: (nn.Module) current model
        :param: ALset: (class: ActiveLearning_Framework)
        :param: num_queried_each_round: (int) number of queried data each AL iteration
        :param: batch_size: (int) dataloader batchsize
        :param: device: (torch.device) GPU/CPU 
        """
        self.model = model
        self.ALset = ALset
        self.num_queried = num_queried_each_round
        self.batch_size = batch_size
        self.device = device
        unlabeled_pool = self.ALset.get_unlabeled_dataset()
        # One score slot per unlabeled sample, filled in dataloader order by query().
        self.entropy_array = np.zeros(len(unlabeled_pool))
        # shuffle=False keeps score positions aligned with ALset.unlabeled_idx.
        self.dataloader = torch.utils.data.DataLoader(unlabeled_pool, batch_size=self.batch_size, shuffle=False)

    def query(self):
        """
        score every unlabeled sample, then pass the dataset indices of the
        `num_queried` highest-scoring samples to `ALset.Update_AL_Datapool`
        (an empty list when `num_queried` is zero or negative, the whole pool
        when it exceeds the pool size)
        """
        self.model.eval()
        print('calculating entropy...')
        pointer = 0
        with torch.no_grad():
            for data in tqdm(self.dataloader):
                # Labels are not needed for scoring, so only the images are moved to the device.
                images = data[0].to(self.device)
                # The model is expected to return a 3-tuple whose first item is the class output.
                outputs, _, _ = self.model(images)
                # NOTE(review): presumably the entropy of the predicted distribution
                # (cross entropy of the output with itself) -- confirm in utils.
                entropy = utils.expected_cross_entropy(outputs, outputs)
                self.entropy_array[pointer: pointer + len(entropy)] = entropy.cpu()
                pointer += len(entropy)
        ranking = self.entropy_array.argsort()
        # A plain `[-k:]` slice returns the WHOLE array for k == 0 (since -0 == 0),
        # so clamp the budget and slice from an explicit start index instead.
        budget = min(max(self.num_queried, 0), len(ranking))
        data_selected = ranking[len(ranking) - budget:]
        idx_selected = [self.ALset.unlabeled_idx[i] for i in data_selected]
        self.ALset.Update_AL_Datapool(idx_selected)


class kCenterGreedy:   # reference https://github.com/google/active-learning/blob/master/sampling_methods/kcenter_greedy.py
    def __init__(self, model:Module, ALset:ActiveLearning_Framework, num_queried_each_round, batch_size, device, num_class, metric='euclidean'):
        """
        use KCenter greedy as query strategy
        :param: model: (nn.Module) current model
        :param: ALset: (class: ActiveLearning_Framework)
        :param: num_queried_each_round: (int) number of queried data each AL iteration
        :param: batch_size: (int) dataloader batchsize
        :param: device: (torch.device) GPU/CPU
        :param: num_class: (int) number of image class in dataset
        :param: metric: (str) metric for measuring pairwise distance
        """
        self.model = model
        self.ALset = ALset
        self.num_queried = num_queried_each_round
        self.batch_size = batch_size
        self.device = device
        self.num_class = num_class
        self.metric = metric
        self.min_distances = None      # (n, 1) distance of every sample to its nearest center
        self.features = None           # (n, d) features: labeled rows first, then unlabeled rows
        self.already_selected = None   # row indices in self.features that are already centers
        # NOTE(review): both loaders use batch_size=1 so the enumerate() index in
        # _extract_features is the sample row; self.batch_size is stored but unused here.
        self.labeled_pool = self.ALset.get_train_dataset()
        self.dataloader_labeled = torch.utils.data.DataLoader(self.labeled_pool, batch_size=1, shuffle=False)
        self.unlabeled_pool = self.ALset.get_unlabeled_dataset()
        self.dataloader_unlabeled = torch.utils.data.DataLoader(self.unlabeled_pool, batch_size=1, shuffle=False)

    def update_distances(self, new_center_idx, only_new=True, reset_dist=False):
        """
        update the distance of every sample to its nearest center after new centers are added
        :param: new_center_idx: (list) row indices (into self.features) of the new centers
        :param: only_new: (bool) if True, centers already in self.already_selected are skipped
        :param: reset_dist: (bool) if True, previously stored minimum distances are discarded first
        """
        if reset_dist:
            self.min_distances = None
        if only_new:
            new_center_idx = [d for d in new_center_idx if d not in self.already_selected]
        if len(new_center_idx) > 0:
            new_center = self.features[new_center_idx, :]
            dist = pairwise_distances(self.features, new_center, metric=self.metric)
            # `dist` has one column per new center. Reduce it to the nearest new
            # center first so min_distances stays (n, 1) even when several
            # centers are added at once.
            nearest_new = np.min(dist, axis=1).reshape(-1, 1)
            if self.min_distances is None:
                self.min_distances = nearest_new
            else:
                self.min_distances = np.minimum(self.min_distances, nearest_new)

    def _extract_features(self, dataloader, row_offset):
        """Run the model over `dataloader` and store each sample's feature at row `row_offset + i`."""
        for i, data in enumerate(dataloader):
            images = data[0].to(self.device)
            # The model is expected to return a 3-tuple whose second item is the feature.
            _, feature, _ = self.model(images)
            if self.features is None:
                # Allocated on the first feature seen, so an empty labeled pool
                # no longer leaves the buffer unallocated.
                total = len(self.labeled_pool) + len(self.unlabeled_pool)
                self.features = np.zeros((total, feature.shape[1]))
            self.features[row_offset + i, :] = feature.ravel().cpu().numpy()

    def query(self):
        """
        greedily pick the unlabeled samples farthest from all current centers and
        pass their dataset indices to `ALset.Update_AL_Datapool`; at most
        `num_queried` samples are picked, fewer when the unlabeled pool is smaller
        """
        self.model.eval()
        num_labeled = len(self.labeled_pool)
        # Never ask for more centers than there are unlabeled samples (the other
        # strategies in this file truncate the same way), and treat a negative
        # request as zero.
        budget = min(max(self.num_queried, 0), len(self.unlabeled_pool))
        data_selected = []
        with torch.no_grad():
            print('calculating image features...\n')
            self.features = None
            self._extract_features(self.dataloader_labeled, 0)
            self.already_selected = list(range(num_labeled))
            self._extract_features(self.dataloader_unlabeled, num_labeled)

            print('starting k-center algorithm...\n')
            self.update_distances(self.already_selected, only_new=False, reset_dist=True)
            while len(data_selected) < budget:
                if self.min_distances is None:
                    # No center yet (empty labeled pool): start from a random datapoint.
                    new_center_idx = int(np.random.choice(np.arange((self.features).shape[0])))
                else:
                    # Existing centers are masked out so that, when every remaining
                    # sample coincides with a center (all distances are zero), an
                    # unselected sample is still returned instead of a center.
                    candidates = self.min_distances.ravel().copy()
                    candidates[self.already_selected] = -np.inf
                    new_center_idx = int(np.argmax(candidates))

                self.update_distances([new_center_idx], only_new=True, reset_dist=False)
                data_selected.append(new_center_idx)
                self.already_selected.append(new_center_idx)
                if len(data_selected) % 200 == 0:
                    print('number of data queried:', len(data_selected), '/', self.num_queried, '\n')  
        # Rows of self.features are offset by the labeled pool; undo that to index unlabeled_idx.
        idx_selected = [self.ALset.unlabeled_idx[i - num_labeled] for i in data_selected]
        self.ALset.Update_AL_Datapool(idx_selected)


class BADGE:    # reference: https://github.com/JordanAsh/badge
    def __init__(self, model: Module, ALset:ActiveLearning_Framework, num_queried_each_round, batch_size, device, num_class):
        """
        use BADGE as query strategy
        :param: model: (nn.Module) current model
        :param: ALset: (class: ActiveLearning_Framework)
        :param: num_queried_each_round: (int) number of queried data each AL iteration
        :param: batch_size: (int) dataloader batchsize
        :param: device: (torch.device) GPU/CPU
        :param: num_class: (int) number of image class in dataset
        """
        self.model = model
        self.ALset = ALset
        self.num_queried = num_queried_each_round
        self.batch_size = batch_size
        self.device = device
        self.num_class = num_class
        self.unlabeled_pool = self.ALset.get_unlabeled_dataset_AL()
        # shuffle=False keeps embedding rows aligned with ALset.unlabeled_idx.
        self.dataloader = torch.utils.data.DataLoader(self.unlabeled_pool, batch_size=self.batch_size, shuffle=False)

    def query(self):
        """
        compute gradient embeddings for the unlabeled pool, pick `num_queried`
        of them with k-means++ seeding and pass their dataset indices to
        `ALset.Update_AL_Datapool`
        """
        gradEmbedding = self.get_grad_embedding().numpy()
        data_selected = self.init_centers(gradEmbedding, self.num_queried)
        idx_selected = [self.ALset.unlabeled_idx[i] for i in data_selected]
        self.ALset.Update_AL_Datapool(idx_selected)

    def get_grad_embedding(self):
        """
        compute the gradient embedding for each unlabeled data, return the embedding as a torch tensor
        of shape (len(unlabeled_pool), embedding_dim * num_class); for class c the block of row j is
        out[j] * (1[c == predicted class of j] - prob[j][c])
        """
        embDim = self.model.get_embedding_dim()
        embedding = np.zeros([len(self.unlabeled_pool), embDim * self.num_class])
        self.model.eval()
        with torch.no_grad():
            pointer = 0
            for data in tqdm(self.dataloader):
                images = data[0].to(self.device)
                # The model is expected to return a 3-tuple: class output first, and the
                # embedding used for the gradient as the third item.
                outputs, _, out = self.model(images)
                out = out.detach().cpu().numpy()
                batchProbs = F.softmax(outputs, dim=1).detach().cpu().numpy()
                batch_len = len(out)
                # Per-class scale: (1 - p) for the predicted class, -p for every other class.
                scale = -batchProbs
                scale[np.arange(batch_len), np.argmax(batchProbs, 1)] += 1
                # Outer product of scale and embedding for each sample, flattened
                # class-major so class c occupies columns [embDim*c, embDim*(c+1)).
                per_class = scale[:, :self.num_class, None] * out[:, None, :]
                embedding[pointer:pointer + batch_len] = per_class.reshape(batch_len, -1)
                pointer += batch_len
        return torch.Tensor(embedding)

    # kmeans ++ initialization
    def init_centers(self,X, K):
        """
        Kmeans++ seeding, return the row indices of the chosen centers
        :param: X: (array-like, one row per sample) input data
        :param: K: (int) number of clustering centers; capped at the number of rows in X
        :return: (list of int) distinct row indices, the first being the row with the largest L2 norm;
                 empty when K <= 0 or X has no rows
        """
        X = np.asarray(X)
        num_points = len(X)
        K = min(K, num_points)
        if K <= 0:
            return []

        indsAll = [int(np.argmax(np.linalg.norm(X, ord=2, axis=1)))]
        chosen = set(indsAll)
        D2 = None   # distance of every sample to its nearest chosen center
        print('#Samps\tTotal Distance')
        while len(indsAll) < K:
            newD = pairwise_distances(X, X[[indsAll[-1]]]).ravel().astype(float)
            D2 = newD if D2 is None else np.minimum(D2, newD)
            if len(indsAll) % 100 == 0:
                print(str(len(indsAll)) + '\t' + str(D2.sum()), flush=True)

            weights = D2 ** 2
            total = weights.sum()
            if total > 0:
                # Sample the next center with probability proportional to squared distance.
                customDist = stats.rv_discrete(name='custm', values=(np.arange(num_points), weights / total))
                ind = int(customDist.rvs(size=1)[0])
                while ind in chosen:
                    ind = int(customDist.rvs(size=1)[0])
            else:
                # Every remaining sample coincides with a chosen center, so the
                # sampling weights are all zero; fall back to a uniform pick among
                # the samples not chosen yet instead of dividing by zero.
                remaining = [i for i in range(num_points) if i not in chosen]
                ind = int(np.random.choice(remaining))
            indsAll.append(ind)
            chosen.add(ind)
        return indsAll













        